# Standard libraries
import os
import re
import sys
import pickle
import datetime
import warnings
from itertools import combinations
from tqdm import tqdm

# Data handling
import numpy as np
import pandas as pd

# Plotting
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# Statistical modeling
import scipy.stats as stats
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.tsa.arima.model import ARIMA
from pmdarima.arima import auto_arima, ndiffs, nsdiffs

import warnings
# Suppress FutureWarnings globally, including for sklearn and specific pmdarima messages.
# PYTHONWARNINGS is only read at interpreter start-up, so setting it here does not
# change this process's filters; it is inherited by child processes started later
# (presumably the parallel workers spawned by n_jobs=-1 — confirm).
os.environ['PYTHONWARNINGS'] = 'ignore::FutureWarning'
# Reset the recorded -W command-line options to an empty list.
sys.warnoptions = []
warnings.simplefilter('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=FutureWarning, module='sklearn')
# The ARIMA calls in this file pass a seasonal period m together with
# seasonal=False; this silences pmdarima's message about resetting m.
warnings.filterwarnings('ignore', message=r"m \(\d+\) set for non-seasonal fit. Setting to 0", module='pmdarima')
# Suppress warnings globally. filterwarnings() inserts at the FRONT of the filter
# list, so this catch-all matches first and makes the targeted filters above
# redundant within this process.
# NOTE(review): this also hides any model-fitting warnings (e.g. convergence
# problems) raised while fitting the OLS/ARIMA models.
warnings.filterwarnings("ignore")

# --- Paths & Constants ---
# Input-data and output directories. Both keep a trailing slash because other
# code in this file builds paths by string concatenation onto them.
its_path_main = "./Replication_Materials/data/"
its_path_result = "./Results/"

# --- Set Intervention Date ---
# Weeks whose start date is on or after this date are coded as post-intervention.
policy_week = datetime.datetime(2018, 8, 10)
# Week position used to draw the vertical intervention marker on the ITS plots.
policy_index = 53

# --- Load Week Index Mapping ---
# Lookup from 'week_old' to week start date. Loaded with pickle, which can
# execute code: only ever point this at trusted replication files.
with open(os.path.join(its_path_main, "week_breaks_index.pkl"), "rb") as pkl_file:
    week_breaks_index = pickle.load(pkl_file)

# --- Load Data ---
# One weekly summary table per content category.
summary_df_domains = pd.read_csv(os.path.join(its_path_main, "summary_df_domains.csv"))
summary_df_hashtags = pd.read_csv(os.path.join(its_path_main, "summary_df_hashtags.csv"))
summary_df_identity = pd.read_csv(os.path.join(its_path_main, "summary_df_identity.csv"))
summary_df_hate = pd.read_csv(os.path.join(its_path_main, "summary_df_hate.csv"))

# --- ITS Data Preparation ---
def get_users_agg_data_its(users_data_agg, policy_week, week_index=None):
    """
    Prepare weekly aggregate data for Interrupted Time Series (ITS) analysis.

    The input DataFrame is never modified; all work happens on a copy.

    Parameters:
    users_data_agg (DataFrame): Weekly data with at least the columns
        'week_old' (key into the week-index mapping) and 'y' (the outcome).
    policy_week (datetime): Start date of the policy intervention. Weeks whose
        start date is on or after this date are coded as post-intervention.
    week_index (dict or Series, optional): Mapping from 'week_old' to the
        week start date. Defaults to the module-level ``week_breaks_index``.

    Returns:
    DataFrame: Columns 'y', 'week', 'intervention' and 'intervention_week',
    indexed by 'week_start_date' and sorted chronologically.
    """
    if week_index is None:
        week_index = week_breaks_index

    data = users_data_agg.copy()

    # Attach the calendar start date of each week.
    data['week_start_date'] = data['week_old'].map(week_index)

    # The cumulative sum below only counts "weeks since intervention" when rows
    # are in time order, so enforce chronological order (stable sort).
    data = data.sort_values('week_start_date', kind='mergesort')

    # Binary post-intervention indicator (1 from the policy week onwards).
    data['intervention'] = (data['week_start_date'] >= policy_week).astype(int)

    # Time trend: week_old shifted by one.
    data['week'] = data['week_old'] + 1

    # Weeks elapsed since the intervention: 0 before it, then 1, 2, 3, ...
    data['intervention_week'] = data['intervention'].cumsum()

    data = data.set_index('week_start_date')

    return data[['y', 'week', 'intervention', 'intervention_week']]

# --- OLS Model for ITS ---
def do_ols_its(data):
    """
    Fit the segmented-regression OLS model for the ITS analysis.

    Regresses 'y' on the time trend 'week', the post-intervention dummy
    'intervention' and the weeks-since-intervention counter
    'intervention_week'.

    Parameters:
    data (DataFrame): ITS data containing the four columns named above.

    Returns:
    The fitted statsmodels OLS results object.
    """
    its_formula = 'y ~ week + intervention + intervention_week'
    fitted = smf.ols(formula=its_formula, data=data).fit()
    return fitted

# --- ARIMA Model ---
def do_arima(data, ols_results, m=None):
    """
    Fit the ITS regression with ARIMA errors using pmdarima's auto_arima.

    The differencing order ``d`` is the largest order suggested by ndiffs with
    its default test, the ADF test and the PP test, all applied to the OLS
    residuals. The seasonal differencing order ``D`` is the larger of nsdiffs
    with its default test and the Canova-Hansen ('ch') test for period ``m``.

    Parameters:
    data (DataFrame): ITS data with columns 'y', 'week', 'intervention' and
        'intervention_week'.
    ols_results: Fitted OLS results; their residuals drive the differencing tests.
    m (int, optional): Seasonal period. Seasonal differencing tests require
        m > 1; when m is None or <= 1, D is set to 0 and auto_arima receives
        m=1 instead of failing.

    Returns:
    The fitted pmdarima model, with the ITS regressors passed as exogenous X.
    """
    resid = ols_results.resid

    d = max(ndiffs(resid, max_d=10),
            ndiffs(resid, max_d=10, test='adf'),
            ndiffs(resid, max_d=10, test='pp'))

    # nsdiffs cannot run without a seasonal period > 1.
    if m is not None and m > 1:
        D = max(nsdiffs(resid, max_D=10, m=m),
                nsdiffs(resid, max_D=10, m=m, test='ch'))
        arima_m = m
    else:
        D = 0
        arima_m = 1

    in_data = data['y']
    xreg = data[['week', 'intervention', 'intervention_week']]

    # seasonal=False: a non-seasonal model is searched exhaustively
    # (stepwise=False); random_state fixes any randomised search behaviour.
    out_model = auto_arima(in_data, n_jobs=-1, d=d, D=D, X=xreg, stepwise=False,
                           seasonal=False, random_state=1,
                           method='powell', m=arima_m, suppress_warnings=True)

    return out_model

# --- Model Plot ---
def do_arima_plot(data, arima_results, policy_week, name, m=None):
    """
    Plot observed data, the in-sample ARIMA fit and a counterfactual forecast.

    A counterfactual ARIMA model is fitted on the pre-intervention observations
    only and forecast over the post-intervention weeks. The figure is saved to
    ``its_path_result + 'its_results/its_plot_<name>.png'`` (the directory is
    created if needed), shown, and then closed.

    Parameters:
    data (DataFrame): ITS data indexed by week start date, with columns 'y'
        and 'week'.
    arima_results: Fitted pmdarima model; its in-sample fitted values are
        drawn as "Model prediction" over the pre-intervention period.
    policy_week (datetime): Rows with index <= policy_week are treated as
        pre-intervention, later rows as post-intervention.
    name (str): Outcome label used in the legend, y-axis label and file name.
    m (int, optional): Seasonal period for the seasonal-differencing test.
        When None or <= 1, no seasonal differencing is used.

    Raises:
    ValueError: If there are no pre- or no post-intervention observations.
    """
    pre = data[data.index <= policy_week]
    post = data[data.index > policy_week]
    if pre.empty or post.empty:
        raise ValueError(
            f"Need observations both before and after {policy_week} to plot a counterfactual."
        )

    # In-sample fitted values of the full ITS ARIMA model.
    y_pred = pd.DataFrame(arima_results.arima_res_.fittedvalues, columns=['y'])

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)

        # Counterfactual model: fitted to the pre-intervention series only.
        n_diffs = ndiffs(pre['y'], test='adf')
        # Seasonal-differencing tests need a seasonal period > 1.
        if m is not None and m > 1:
            n_Diffs = nsdiffs(pre['y'], m=m, max_D=3, test='ch')
            cf_m = m
        else:
            n_Diffs, cf_m = 0, 1

        arima_cf = auto_arima(
            pre['y'],
            seasonal=False,
            m=cf_m,
            d=n_diffs,
            D=n_Diffs,
            start_p=0,
            start_q=0,
            max_p=5,
            max_q=5,
            start_P=0,
            start_Q=0,
            max_P=5,
            max_Q=5,
            trace=False,
            error_action='ignore',
            suppress_warnings=True,
            n_jobs=-1,
            stepwise=False,
        )

        # Forecasting from a fitted model is deterministic, so a single call
        # yields the counterfactual point forecast for the post period.
        y_cf, _ = arima_cf.predict(
            n_periods=len(post),
            return_conf_int=True, type='levels',
            index=post.index,
        )
    y_cf = np.asarray(y_cf)

    fig, ax = plt.subplots(figsize=(30, 8))

    # The first fitted value is skipped, presumably because it carries no
    # information for a differenced model.
    ax.plot(
        pre["week"].iloc[1:],
        y_pred[y_pred.index <= policy_week]["y"].iloc[1:],
        color='black',
        linestyle='--',
        linewidth=3,
        label='Model prediction'
    )

    ax.plot(
        data["week"],
        data["y"],
        color='gray',
        linewidth=3,
        label=f'Observed {name}'
    )

    ax.plot(
        post["week"],
        y_cf,
        color='dimgray',
        marker='.',
        linestyle='None',
        markersize=10,
        label="Counterfactual"
    )

    # Intervention marker, positioned from the module-level policy_index.
    ax.axvline(
        x=policy_index + 1.5,
        color='darkgray',
        linewidth=3,
        label='Proud Boys Ban'
    )

    # One x tick per month, placed at the first week of each month in the data.
    months = data.index.to_series().dt.to_period("M")
    first_weekly_dates_of_month = data.index[(months != months.shift()).to_numpy()]
    show_x_ticklabels = np.array(data[data.index.isin(first_weekly_dates_of_month)].week.tolist())
    ax.xaxis.set_ticks(show_x_ticklabels)
    ax.set_xticklabels(data[data.week.isin(show_x_ticklabels)].index.strftime('%Y-%m').values,
                       rotation=45, fontsize=20)

    ax.yaxis.set_tick_params(labelsize=20)

    ax.legend(loc='upper left', fontsize=20)
    ax.set_ylabel(name, fontsize=20)
    ax.title.set_size(30)

    fig.tight_layout()

    out_dir = os.path.join(its_path_result, 'its_results')
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f'its_plot_{name}.png'))
    plt.show()
    # Release the (large) figure; otherwise every call leaves one open.
    plt.close(fig)


# -- Calculate Effect Size ---
def calculate_effect_size(data, arima_results, nsim, policy_week, name, m=None):
    """
    Estimate the post-intervention effect against a simulated counterfactual.

    An ARIMA model is fitted to the pre-intervention series. ``nsim`` future
    paths covering the post-intervention period are simulated from it, and
    each path is reduced to its mean. The effect is the observed post-period
    mean minus the average of those simulated means. The standard deviation of
    the simulated means gives a normal-approximation 95% interval
    (+/- 1.96 * std), a z-score and a two-sided p-value. No random seed is set
    here, so results vary between runs.

    The exponentiated columns ('Risk ratio', 'Lower CI', 'Upper CI') are only
    interpretable as ratios when 'y' is on the log scale.

    The table is written to
    ``its_path_result + 'its_results/its_effect_size_<name>.csv'`` (the
    directory is created if needed). It is then shown with IPython's
    ``display`` when available, otherwise printed, and returned.

    Parameters:
    data (DataFrame): ITS data indexed by week start date, with column 'y'.
    arima_results: Unused; kept for interface compatibility.
    nsim (int): Number of counterfactual simulations.
    policy_week (datetime): Rows with index <= policy_week are pre-intervention.
    name (str): Outcome label used in the table and the output file name.
    m (int, optional): Seasonal period for the seasonal-differencing test.
        When None or <= 1, no seasonal differencing is used.

    Returns:
    DataFrame: One-row effect-size table.

    Raises:
    ValueError: If there are no pre- or no post-intervention observations.
    """
    pre_y = data[data.index <= policy_week]['y']
    post_y = data[data.index > policy_week]['y']
    horizon = len(post_y)
    if pre_y.empty or horizon == 0:
        raise ValueError(
            f"Need observations both before and after {policy_week} to estimate an effect."
        )

    # Counterfactual model: fitted to the pre-intervention series only.
    n_diffs = ndiffs(pre_y, test='adf')
    # Seasonal-differencing tests need a seasonal period > 1.
    if m is not None and m > 1:
        n_Diffs = nsdiffs(pre_y, m=m, max_D=3, test='ch')
        cf_m = m
    else:
        n_Diffs, cf_m = 0, 1

    arima_cf = auto_arima(pre_y,
                          seasonal=False,
                          m=cf_m,
                          d=n_diffs,
                          D=n_Diffs,
                          start_p=0,
                          start_q=0,
                          max_p=5,
                          max_q=5,
                          start_P=0,
                          start_Q=0,
                          max_P=5,
                          max_Q=5,
                          trace=False,
                          error_action='ignore',
                          suppress_warnings=True,
                          n_jobs=-1,
                          stepwise=False)

    # Mean of each simulated post-period path, anchored at the end of the
    # pre-intervention sample.
    sim = np.array([
        np.sum(arima_cf.arima_res_.simulate(horizon, anchor='end')) / horizon
        for _ in range(nsim)
    ])

    y_effect_mean_diff = np.mean(post_y) - np.mean(sim)
    y_effect_std = np.std(sim)
    y_effect_lower_ci = y_effect_mean_diff - 1.96 * y_effect_std
    y_effect_upper_ci = y_effect_mean_diff + 1.96 * y_effect_std

    y_effect_zscore = y_effect_mean_diff / y_effect_std
    y_effect_pvalue = stats.norm.sf(abs(y_effect_zscore)) * 2

    y_effect_prop_diff = np.exp(y_effect_mean_diff)
    y_effect_prop_lower_ci = np.exp(y_effect_lower_ci)
    y_effect_prop_upper_ci = np.exp(y_effect_upper_ci)

    effect_table = pd.DataFrame({'Variable': [name],
                                 'Log Mean Difference': [round(y_effect_mean_diff, 3)],
                                 'Risk ratio': [round(y_effect_prop_diff, 3)],
                                 'Std': [round(y_effect_std, 3)],

                                 'Lower Log CI': [round(y_effect_lower_ci, 3)],
                                 'Upper Log CI': [round(y_effect_upper_ci, 3)],

                                 'Lower CI': [round(y_effect_prop_lower_ci, 3)],
                                 'Upper CI': [round(y_effect_prop_upper_ci, 3)],

                                 'Z-score': [round(y_effect_zscore, 3)],
                                 'P-value': [round(y_effect_pvalue, 4)]})

    out_dir = os.path.join(its_path_result, 'its_results')
    os.makedirs(out_dir, exist_ok=True)
    effect_table.to_csv(os.path.join(out_dir, f"its_effect_size_{name}.csv"), index=False)

    # `display` only exists inside IPython/Jupyter sessions; fall back to
    # print when running as a plain script.
    try:
        show = display
    except NameError:
        show = print
    show(effect_table)

    return effect_table


# --- Main Analysis Wrapper ---
def analyze_its(data, var_name, policy_week, res_name, set_log=False, m=5, nsim=1000):
    """
    Run the full ITS pipeline for one outcome column.

    Steps:
    1. Build the outcome 'y', optionally log-transformed.
    2. Prepare the ITS design (get_users_agg_data_its).
    3. Fit the segmented OLS model (do_ols_its).
    4. Fit the regression-with-ARIMA-errors model (do_arima).
    5. Plot it (do_arima_plot).
    6. Compute and save the effect size (calculate_effect_size).

    The input DataFrame is not modified.

    Parameters:
    data (DataFrame): Weekly summary data containing 'week_old' and var_name.
    var_name (str): Name of the column to analyse.
    policy_week (datetime): Date of the policy intervention.
    res_name (str): Label used in plots and output file names.
    set_log (bool): If True, model log(var_name); all values must then be > 0.
    m (int): Seasonal period passed to the ARIMA steps (default 5).
    nsim (int): Number of counterfactual simulations for the effect size
        (default 1000).

    Returns:
    None: Outputs are saved/plotted by the called functions.

    Raises:
    ValueError: If set_log is True and var_name contains non-positive values,
        which would otherwise turn into -inf/NaN and break the model fits.
    """
    df = data.copy()

    if set_log:
        values = df[var_name]
        if (values <= 0).any():
            raise ValueError(
                f"Cannot log-transform '{var_name}': it contains non-positive values."
            )
        df['y'] = np.log(values)
    else:
        df['y'] = df[var_name]

    its_data = get_users_agg_data_its(users_data_agg=df, policy_week=policy_week)

    ols_results = do_ols_its(its_data)

    arima_results = do_arima(data=its_data, ols_results=ols_results, m=m)

    do_arima_plot(data=its_data, arima_results=arima_results,
                  policy_week=policy_week, name=res_name, m=m)

    calculate_effect_size(data=its_data, arima_results=arima_results,
                          nsim=nsim, policy_week=policy_week, name=res_name, m=m)


# -- Get Results ---
# One entry per ITS analysis, run in this order:
# (summary table, outcome column, output label, log-transform outcome?)
its_runs = [
    (summary_df_domains, 'weekly_freq_of_tweets_with_pb_domains',
     'URLs frequency (log)', True),
    (summary_df_domains, 'pb_domain_prop_in_tweets_with_domains',
     'URLs proportion (as decimal)', False),
    (summary_df_hashtags, 'weekly_freq_of_tweets_with_pb_hashtags',
     'Hashtag frequency (log)', True),
    (summary_df_hashtags, 'pb_hashtag_prop_in_tweets_with_hashtags',
     'Hashtag proportion (as decimal)', False),
    (summary_df_hate, 'weekly_freq_of_tweets_with_pb_hate',
     'Ideological narrative frequency (log)', True),
    (summary_df_hate, 'pb_hate_prop_in_tweets',
     'Ideological narrative proportion (as decimal)', False),
    (summary_df_identity, 'weekly_freq_of_tweets_with_pb_identity',
     'Identity frequency (log)', True),
    (summary_df_identity, 'pb_identity_prop_in_tweets',
     'Identity proportion (as decimal)', False),
]

for run_data, run_var, run_label, run_log in its_runs:
    analyze_its(data=run_data,
                var_name=run_var,
                policy_week=policy_week,
                res_name=run_label,
                set_log=run_log)




